from collections import defaultdict
import json
from json import JSONDecodeError
from prettytable import PrettyTable

from mqtt_pwn.models.message import Message
from mqtt_pwn.models.scan import Scan
from mqtt_pwn.models.topic import Topic
from mqtt_pwn.utils import drop_none


class OwnTracksExploit(object):
    """Represents the owntracks exploit.

    Collects the OwnTracks messages captured during a scan and turns each
    (user, device) pair's location reports into a Google Maps directions URL.
    """

    def __init__(self, scan_id):
        self.scan_id = scan_id
        # (user, device) -> list of decoded payloads that carry coordinates
        self._data = {}
        # Whether the database was already queried; a flag (rather than
        # checking for an empty ``_data``) avoids re-querying on every call
        # when the scan simply has no owntracks messages.
        self._parsed = False

    def _ensure_parsed(self):
        """Loads and parses the scan's messages at most once per instance."""
        if not self._parsed:
            self._parse_messages()

    def _parse_messages(self):
        """Parses the messages according to the owntracks message format.

        Only topics of the form ``owntracks/<user>/<device>`` are used.
        Message bodies are de-duplicated (keeping first-seen order), decoded
        as JSON, reduced to payloads that carry both ``lat`` and ``lon``, and
        stably sorted by their ``tst`` timestamp so the route is chronological.
        """
        messages = Message \
            .select(Message, Topic) \
            .join(Scan, on=(Scan.id == Message.scan)) \
            .join(Topic, on=(Topic.id == Message.topic)) \
            .where(
                Topic.name.startswith('owntracks/'),
                Scan.id == self.scan_id
            )

        # dict values are used as ordered sets: they de-duplicate like the
        # previous ``set`` but keep insertion order, so the output is
        # deterministic across runs (set order depends on hash randomization).
        bodies_by_name = defaultdict(dict)

        for m in messages:
            name = OwnTracksExploit.label_to_name(m.topic.name)
            if name is None:
                continue
            bodies_by_name[name].setdefault(m.body, None)

        for name, bodies in bodies_by_name.items():
            decoded = (OwnTracksExploit.body_to_json(b) for b in bodies)
            locations = [
                d for d in decoded if OwnTracksExploit._is_location(d)
            ]
            locations.sort(key=OwnTracksExploit._timestamp)
            self._data[name] = locations

        self._parsed = True

    @staticmethod
    def _is_location(payload):
        """Returns True if ``payload`` is a JSON object with ``lat`` and ``lon``.

        Filters out undecodable bodies (None), non-object JSON values, and
        OwnTracks payloads without coordinates (e.g. ``_type: lwt``), which
        would otherwise be plotted at 0,0.
        """
        return isinstance(payload, dict) and 'lat' in payload and 'lon' in payload

    @staticmethod
    def _timestamp(payload):
        """Sort key: the numeric ``tst`` field, or 0 when missing/non-numeric."""
        tst = payload.get('tst')
        return tst if isinstance(tst, (int, float)) else 0

    @staticmethod
    def label_to_name(label):
        """Converts a topic label to a (user, device) tuple.

        Returns None unless the part after ``owntracks/`` has exactly two
        ``/``-separated segments.
        """

        _, _, t = label.partition('owntracks/')
        t = t.split('/')

        if len(t) != 2:
            return None

        return t[0], t[1]

    @staticmethod
    def body_to_json(body):
        """Converts the body of a message to json.

        Returns None when the body cannot be decoded.
        """
        try:
            return json.loads(body)
        except (TypeError, ValueError):
            # JSONDecodeError and UnicodeDecodeError subclass ValueError;
            # TypeError covers a None / non-text body.
            return None

    def create_urls_table(self):
        """Creates the summary table from the messages.

        One row per (user, device) with the number of coordinates that
        would be placed in that device's Google Maps URL.
        """
        self._ensure_parsed()

        t = PrettyTable(field_names=[
            'User', 'Device', '# Coords'
        ])

        for row in self._create_all_urls():
            t.add_row(row)

        return t

    def _create_google_maps_url(self, coordinations: list):
        """Creates the google maps directions URL from the given payloads.

        :param coordinations: decoded OwnTracks payloads; entries lacking
            ``lat`` or ``lon`` are skipped.
        :return: tuple of (url, number of coordinates in the url)
        """

        base_url = "https://www.google.com/maps/dir"
        points = [
            f"{coord['lat']},{coord['lon']}"
            for coord in coordinations
            if OwnTracksExploit._is_location(coord)
        ]

        return '/'.join([base_url] + points), len(points)

    def _create_all_urls(self):
        """Yields (user, device, number of coordinates) for every device."""

        for (user, device), coords in self._data.items():
            _, num_of_coords = self._create_google_maps_url(coords)
            yield user, device, num_of_coords

    def google_maps_url(self, user=None, device=None):
        """The public method to create a google maps URL.

        With neither ``user`` nor ``device`` given, returns a generator of
        (user, device, number of coordinates) tuples for every device.
        Otherwise returns the URL for that (user, device) pair, or None when
        no location data exists for it.
        """

        self._ensure_parsed()

        if not user and not device:
            return self._create_all_urls()

        coords = self._data.get((user, device))

        if not coords:
            return None

        url, _ = self._create_google_maps_url(coords)
        return url

